Register classes for UDT deserialization
This commit is contained in:
@@ -334,6 +334,11 @@ class Cluster(object):
|
||||
_prepared_statements = None
|
||||
_prepared_statement_lock = None
|
||||
|
||||
_user_types = None
|
||||
"""
|
||||
A map of {keyspace: {type_name: UserType}}
|
||||
"""
|
||||
|
||||
_listeners = None
|
||||
_listener_lock = None
|
||||
|
||||
@@ -412,6 +417,8 @@ class Cluster(object):
|
||||
self._prepared_statements = WeakValueDictionary()
|
||||
self._prepared_statement_lock = Lock()
|
||||
|
||||
self._user_types = defaultdict(dict)
|
||||
|
||||
self._min_requests_per_connection = {
|
||||
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
|
||||
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
|
||||
@@ -444,6 +451,9 @@ class Cluster(object):
|
||||
self.control_connection = ControlConnection(
|
||||
self, self.control_connection_timeout)
|
||||
|
||||
def register_type_class(self, keyspace, user_type, klass):
|
||||
self._user_types[keyspace][user_type] = klass
|
||||
|
||||
def get_min_requests_per_connection(self, host_distance):
|
||||
return self._min_requests_per_connection[host_distance]
|
||||
|
||||
@@ -517,6 +527,7 @@ class Cluster(object):
|
||||
kwargs_dict['ssl_options'] = self.ssl_options
|
||||
kwargs_dict['cql_version'] = self.cql_version
|
||||
kwargs_dict['protocol_version'] = self.protocol_version
|
||||
kwargs_dict['user_type_map'] = self._user_types
|
||||
|
||||
return kwargs_dict
|
||||
|
||||
|
||||
@@ -155,12 +155,14 @@ class Connection(object):
|
||||
is_defunct = False
|
||||
is_closed = False
|
||||
lock = None
|
||||
user_type_map = None
|
||||
|
||||
is_control_connection = False
|
||||
|
||||
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
|
||||
ssl_options=None, sockopts=None, compression=True,
|
||||
cql_version=None, protocol_version=2, is_control_connection=False):
|
||||
cql_version=None, protocol_version=2, is_control_connection=False,
|
||||
user_type_map=None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.authenticator = authenticator
|
||||
@@ -170,6 +172,7 @@ class Connection(object):
|
||||
self.cql_version = cql_version
|
||||
self.protocol_version = protocol_version
|
||||
self.is_control_connection = is_control_connection
|
||||
self.user_type_map = user_type_map
|
||||
self._push_watchers = defaultdict(set)
|
||||
if protocol_version >= 3:
|
||||
self._header_unpack = v3_header_unpack
|
||||
@@ -347,7 +350,8 @@ class Connection(object):
|
||||
else:
|
||||
raise ProtocolError("Got negative body length: %r" % body_len)
|
||||
|
||||
response = decode_response(given_version, stream_id, flags, opcode, body, self.decompressor)
|
||||
response = decode_response(given_version, self.user_type_map, stream_id,
|
||||
flags, opcode, body, self.decompressor)
|
||||
except Exception as exc:
|
||||
log.exception("Error decoding response from Cassandra. "
|
||||
"opcode: %04x; message contents: %r", opcode, msg)
|
||||
@@ -386,10 +390,10 @@ class Connection(object):
|
||||
if isinstance(options_response, ConnectionException):
|
||||
raise options_response
|
||||
else:
|
||||
log.error("Did not get expected SupportedMessage response; " \
|
||||
log.error("Did not get expected SupportedMessage response; "
|
||||
"instead, got: %s", options_response)
|
||||
raise ConnectionException("Did not get expected SupportedMessage " \
|
||||
"response; instead, got: %s" \
|
||||
raise ConnectionException("Did not get expected SupportedMessage "
|
||||
"response; instead, got: %s"
|
||||
% (options_response,))
|
||||
|
||||
log.debug("Received options response on new connection (%s) from %s",
|
||||
|
||||
@@ -778,14 +778,21 @@ class UserDefinedType(_ParameterizedType):
|
||||
|
||||
FIELD_LENGTH = 4
|
||||
|
||||
_cache = {}
|
||||
|
||||
@classmethod
|
||||
def apply_parameters(cls, subtypes, names):
|
||||
newname = subtypes[1].cassname.decode("hex")
|
||||
field_names = [encoded_name.decode("hex") for encoded_name in names[2:]]
|
||||
return type(newname, (cls,), {'subtypes': subtypes[2:],
|
||||
'cassname': cls.cassname,
|
||||
'typename': newname,
|
||||
'fieldnames': field_names})
|
||||
def apply_parameters(cls, udt_name, names_and_types, mapped_class):
|
||||
try:
|
||||
return cls._cache[udt_name]
|
||||
except KeyError:
|
||||
fieldnames, types = zip(*names_and_types)
|
||||
instance = type(udt_name, (cls,), {'subtypes': types,
|
||||
'cassname': cls.cassname,
|
||||
'typename': udt_name,
|
||||
'fieldnames': fieldnames,
|
||||
'mapped_class': mapped_class})
|
||||
cls._cache[udt_name] = instance
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def cql_parameterized_type(cls):
|
||||
@@ -795,8 +802,7 @@ class UserDefinedType(_ParameterizedType):
|
||||
def deserialize_safe(cls, byts, protocol_version):
|
||||
proto_version = max(3, protocol_version)
|
||||
p = 0
|
||||
Result = namedtuple(cls.typename, cls.fieldnames)
|
||||
result = []
|
||||
values = []
|
||||
for col_type in cls.subtypes:
|
||||
if p == len(byts):
|
||||
break
|
||||
@@ -806,13 +812,17 @@ class UserDefinedType(_ParameterizedType):
|
||||
p += itemlen
|
||||
# collections inside UDTs are always encoded with at least the
|
||||
# version 3 format
|
||||
result.append(col_type.from_binary(item, proto_version))
|
||||
values.append(col_type.from_binary(item, proto_version))
|
||||
|
||||
if len(result) < len(cls.subtypes):
|
||||
nones = [None] * (len(cls.subtypes) - len(result))
|
||||
result = result + nones
|
||||
if len(values) < len(cls.subtypes):
|
||||
nones = [None] * (len(cls.subtypes) - len(values))
|
||||
values = values + nones
|
||||
|
||||
return Result(*result)
|
||||
if cls.mapped_class:
|
||||
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
||||
else:
|
||||
Result = namedtuple(cls.typename, cls.fieldnames)
|
||||
return Result(*values)
|
||||
|
||||
@classmethod
|
||||
def serialize_safe(cls, val, protocol_version):
|
||||
|
||||
@@ -100,7 +100,8 @@ def _get_params(message_obj):
|
||||
)
|
||||
|
||||
|
||||
def decode_response(protocol_version, stream_id, flags, opcode, body, decompressor=None):
|
||||
def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, body,
|
||||
decompressor=None):
|
||||
if flags & COMPRESSED_FLAG:
|
||||
if decompressor is None:
|
||||
raise Exception("No de-compressor available for compressed frame!")
|
||||
@@ -118,7 +119,7 @@ def decode_response(protocol_version, stream_id, flags, opcode, body, decompress
|
||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
|
||||
msg_class = _message_types_by_opcode[opcode]
|
||||
msg = msg_class.recv_body(body, protocol_version)
|
||||
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
||||
msg.stream_id = stream_id
|
||||
msg.trace_id = trace_id
|
||||
return msg
|
||||
@@ -138,7 +139,7 @@ class ErrorMessage(_MessageType, Exception):
|
||||
self.info = info
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
code = read_int(f)
|
||||
msg = read_string(f)
|
||||
subcls = error_classes.get(code, cls)
|
||||
@@ -338,7 +339,7 @@ class ReadyMessage(_MessageType):
|
||||
name = 'READY'
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
return cls()
|
||||
|
||||
|
||||
@@ -350,7 +351,7 @@ class AuthenticateMessage(_MessageType):
|
||||
self.authenticator = authenticator
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
authname = read_string(f)
|
||||
return cls(authenticator=authname)
|
||||
|
||||
@@ -382,7 +383,7 @@ class AuthChallengeMessage(_MessageType):
|
||||
self.challenge = challenge
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
return cls(read_longstring(f))
|
||||
|
||||
|
||||
@@ -405,7 +406,7 @@ class AuthSuccessMessage(_MessageType):
|
||||
self.token = token
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
return cls(read_longstring(f))
|
||||
|
||||
|
||||
@@ -426,7 +427,7 @@ class SupportedMessage(_MessageType):
|
||||
self.options = options
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
options = read_stringmultimap(f)
|
||||
cql_versions = options.pop('CQL_VERSION')
|
||||
return cls(cql_versions=cql_versions, options=options)
|
||||
@@ -547,13 +548,14 @@ class ResultMessage(_MessageType):
|
||||
self.paging_state = paging_state
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
kind = read_int(f)
|
||||
paging_state = None
|
||||
if kind == RESULT_KIND_VOID:
|
||||
results = None
|
||||
elif kind == RESULT_KIND_ROWS:
|
||||
paging_state, results = cls.recv_results_rows(f, protocol_version)
|
||||
paging_state, results = cls.recv_results_rows(
|
||||
f, protocol_version, user_type_map)
|
||||
elif kind == RESULT_KIND_SET_KEYSPACE:
|
||||
ksname = read_string(f)
|
||||
results = ksname
|
||||
@@ -564,16 +566,17 @@ class ResultMessage(_MessageType):
|
||||
return cls(kind, results, paging_state)
|
||||
|
||||
@classmethod
|
||||
def recv_results_rows(cls, f, protocol_version):
|
||||
def recv_results_rows(cls, f, protocol_version, user_type_map):
|
||||
paging_state, column_metadata = cls.recv_results_metadata(f)
|
||||
rowcount = read_int(f)
|
||||
rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
|
||||
colnames = [c[2] for c in column_metadata]
|
||||
coltypes = [c[3] for c in column_metadata]
|
||||
return (
|
||||
paging_state,
|
||||
(colnames, [tuple(ctype.from_binary(val, protocol_version) for ctype, val in zip(coltypes, row))
|
||||
for row in rows]))
|
||||
parsed_rows = [
|
||||
tuple(ctype.from_binary(val, protocol_version)
|
||||
for ctype, val in zip(coltypes, row))
|
||||
for row in rows]
|
||||
return (paging_state, (colnames, parsed_rows))
|
||||
|
||||
@classmethod
|
||||
def recv_results_prepared(cls, f):
|
||||
@@ -582,7 +585,7 @@ class ResultMessage(_MessageType):
|
||||
return (query_id, column_metadata)
|
||||
|
||||
@classmethod
|
||||
def recv_results_metadata(cls, f):
|
||||
def recv_results_metadata(cls, f, user_type_map):
|
||||
flags = read_int(f)
|
||||
glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
|
||||
colcount = read_int(f)
|
||||
@@ -624,7 +627,7 @@ class ResultMessage(_MessageType):
|
||||
return {'change_type': change_type, 'keyspace': keyspace, 'table': table}
|
||||
|
||||
@classmethod
|
||||
def read_type(cls, f):
|
||||
def read_type(cls, f, user_type_map):
|
||||
optid = read_short(f)
|
||||
try:
|
||||
typeclass = cls._type_codes[optid]
|
||||
@@ -638,6 +641,14 @@ class ResultMessage(_MessageType):
|
||||
keysubtype = cls.read_type(f)
|
||||
valsubtype = cls.read_type(f)
|
||||
typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
|
||||
elif typeclass == UserDefinedType:
|
||||
ks = cls.read_string(f)
|
||||
udt_name = cls.read_string(f)
|
||||
num_fields = cls.read_short(f)
|
||||
names_and_types = ((cls.read_string(f), cls.read_type(f))
|
||||
for _ in xrange(num_fields))
|
||||
mapped_class = user_type_map.get(ks, {}).get(udt_name)
|
||||
typeclass = typeclass.apply_parameters(udt_name, names_and_types, mapped_class)
|
||||
elif typeclass == CUSTOM_TYPE:
|
||||
classname = read_string(f)
|
||||
typeclass = lookup_casstype(classname)
|
||||
@@ -789,7 +800,7 @@ class EventMessage(_MessageType):
|
||||
self.event_args = event_args
|
||||
|
||||
@classmethod
|
||||
def recv_body(cls, f, protocol_version):
|
||||
def recv_body(cls, f, protocol_version, user_type_map):
|
||||
event_type = read_string(f).upper()
|
||||
if event_type in known_event_types:
|
||||
read_method = getattr(cls, 'recv_' + event_type.lower())
|
||||
|
||||
@@ -397,27 +397,15 @@ class ControlConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_handle_schema_change(self):
|
||||
|
||||
for change_type in ('CREATED', 'DROPPED'):
|
||||
for change_type in ('CREATED', 'DROPPED', 'UPDATED'):
|
||||
event = {
|
||||
'change_type': change_type,
|
||||
'keyspace': 'ks1',
|
||||
'table': 'table1'
|
||||
}
|
||||
self.control_connection._handle_schema_change(event)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1')
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1', None)
|
||||
|
||||
event['table'] = None
|
||||
self.control_connection._handle_schema_change(event)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, None)
|
||||
|
||||
event = {
|
||||
'change_type': 'UPDATED',
|
||||
'keyspace': 'ks1',
|
||||
'table': 'table1'
|
||||
}
|
||||
self.control_connection._handle_schema_change(event)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', 'table1')
|
||||
|
||||
event['table'] = None
|
||||
self.control_connection._handle_schema_change(event)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None)
|
||||
self.cluster.executor.submit.assert_called_with(self.control_connection.refresh_schema, 'ks1', None, None)
|
||||
|
||||
@@ -105,7 +105,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
kind=RESULT_KIND_SCHEMA_CHANGE,
|
||||
results={'keyspace': "keyspace1", "table": "table1"})
|
||||
rf._set_result(result)
|
||||
session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', ANY, rf)
|
||||
session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', None, ANY, rf)
|
||||
|
||||
def test_other_result_message_kind(self):
|
||||
session = self.make_session()
|
||||
|
||||
Reference in New Issue
Block a user