diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a79aaab7..0fdcfdb0 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -31,7 +31,9 @@ from cassandra.decoder import (QueryMessage, ResultMessage, PrepareMessage, ExecuteMessage, PreparedQueryNotFound, IsBootstrappingErrorMessage, - BatchMessage) + BatchMessage, RESULT_KIND_PREPARED, + RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, + RESULT_KIND_SCHEMA_CHANGE) from cassandra.metadata import Metadata from cassandra.metrics import Metrics from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy, @@ -855,7 +857,7 @@ class Cluster(object): responses = connection.wait_for_responses(*messages, timeout=5.0) for response in responses: if (not isinstance(response, ResultMessage) or - response.kind != ResultMessage.KIND_PREPARED): + response.kind != RESULT_KIND_PREPARED): log.debug("Got unexpected response when preparing " "statement on host %s: %r", host, response) @@ -1068,18 +1070,15 @@ class Session(object): if parameters: parameters = encode_params(parameters) message = QueryMessage( - query=query.query_string, consistency_level=query.consistency_level, - parameters=parameters) + query.query_string, query.consistency_level, parameters) else: query_string = query.query_string if parameters: query_string = bind_params(query.query_string, parameters) - message = QueryMessage(query=query_string, consistency_level=query.consistency_level) + message = QueryMessage(query_string, query.consistency_level) elif isinstance(query, BoundStatement): message = ExecuteMessage( - query_id=query.prepared_statement.query_id, - query_params=query.values, - consistency_level=query.consistency_level) + query.prepared_statement.query_id, query.values, query.consistency_level) prepared_statement = query.prepared_statement elif isinstance(query, BatchStatement): if self._protocol_version < 2: @@ -1088,9 +1087,8 @@ class Session(object): "2 or higher (supported in Cassandra 2.0 and higher). Consider " "setting Cluster.protocol_version to 2 to support this operation.") message = BatchMessage( - batch_type=query.batch_type, - queries=query._statements_and_values, - consistency_level=query.consistency_level) + query.batch_type, query._statements_and_parameters, + query.consistency_level) if trace: message.tracing = True @@ -1949,7 +1947,7 @@ class ResponseFuture(object): self._query_trace = QueryTrace(trace_id, self.session) if isinstance(response, ResultMessage): - if response.kind == ResultMessage.KIND_SET_KEYSPACE: + if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on @@ -1961,7 +1959,7 @@ class ResponseFuture(object): if session: session._set_keyspace_for_all_pools( response.results, self._set_keyspace_completed) - elif response.kind == ResultMessage.KIND_SCHEMA_CHANGE: + elif response.kind == RESULT_KIND_SCHEMA_CHANGE: # refresh the schema before responding, but do it in another # thread instead of the event loop thread self.session.submit( @@ -1972,7 +1970,7 @@ class ResponseFuture(object): self) else: results = getattr(response, 'results', None) - if results is not None and response.kind == ResultMessage.KIND_ROWS: + if results is not None and response.kind == RESULT_KIND_ROWS: results = self.row_factory(*results) self._set_final_result(results) elif isinstance(response, ErrorMessage): @@ -2107,7 +2105,7 @@ class ResponseFuture(object): return if isinstance(response, ResultMessage): - if response.kind == ResultMessage.KIND_PREPARED: + if response.kind == RESULT_KIND_PREPARED: # use self._query to re-use the same host and # at the same time properly borrow the connection request_id = self._query(self._current_host) diff --git a/cassandra/decoder.py b/cassandra/decoder.py index 0edb3009..83a1f7e3 100644 --- a/cassandra/decoder.py +++ b/cassandra/decoder.py @@ -48,19 +48,9 @@ class _register_msg_type(type): class _MessageType(object): __metaclass__ = _register_msg_type - params = () tracing = False - def __init__(self, **kwargs): - for pname in self.params: - try: - pval = kwargs[pname] - except KeyError: - raise ValueError("%s instances need the %s keyword parameter" - % (self.__class__.__name__, pname)) - setattr(self, pname, pval) - def to_string(self, stream_id, protocol_version, compression=None): body = StringIO() self.send_body(body) @@ -77,11 +67,17 @@ class _MessageType(object): return ''.join(msg_parts) def __str__(self): - paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in self.params] + paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)] return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs)) __repr__ = __str__ +def _get_params(message_obj): + base_attrs = dir(_MessageType) + return [a for a in dir(message_obj) + if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))] + + def decode_response(stream_id, flags, opcode, body, decompressor=None): if flags & 0x01: if decompressor is None: @@ -112,9 +108,13 @@ error_classes = {} class ErrorMessage(_MessageType, Exception): opcode = 0x00 name = 'ERROR' - params = ('code', 'message', 'info') summary = 'Unknown' + def __init__(self, code, message, info): + self.code = code + self.message = message + self.info = info + @classmethod def recv_body(cls, f): code = read_int(f) @@ -290,13 +290,16 @@ class AlreadyExistsException(ConfigurationException): class StartupMessage(_MessageType): opcode = 0x01 name = 'STARTUP' - params = ('cqlversion', 'options') KNOWN_OPTION_KEYS = set(( 'CQL_VERSION', 'COMPRESSION', )) + def __init__(self, cqlversion, options): + self.cqlversion = cqlversion + self.options = options + def send_body(self, f): optmap = self.options.copy() optmap['CQL_VERSION'] = self.cqlversion @@ -306,7 +309,6 @@ class StartupMessage(_MessageType): class ReadyMessage(_MessageType): opcode = 0x02 name = 'READY' - params = () @classmethod def recv_body(cls, f): @@ -316,7 +318,9 @@ class ReadyMessage(_MessageType): class AuthenticateMessage(_MessageType): opcode = 0x03 name = 'AUTHENTICATE' - params = ('authenticator',) + + def __init__(self, authenticator): + self.authenticator = authenticator @classmethod def recv_body(cls, f): @@ -327,7 +331,9 @@ class AuthenticateMessage(_MessageType): class CredentialsMessage(_MessageType): opcode = 0x04 name = 'CREDENTIALS' - params = ('creds',) + + def __init__(self, creds): + self.creds = creds def send_body(self, f): write_short(f, len(self.creds)) @@ -339,7 +345,6 @@ class CredentialsMessage(_MessageType): class OptionsMessage(_MessageType): opcode = 0x05 name = 'OPTIONS' - params = () def send_body(self, f): pass @@ -348,7 +353,10 @@ class OptionsMessage(_MessageType): class SupportedMessage(_MessageType): opcode = 0x06 name = 'SUPPORTED' - params = ('cql_versions', 'options',) + + def __init__(self, cql_versions, options): + self.cql_versions = cql_versions + self.options = options @classmethod def recv_body(cls, f): @@ -360,7 +368,17 @@ class SupportedMessage(_MessageType): class QueryMessage(_MessageType): opcode = 0x07 name = 'QUERY' - params = ('query', 'consistency_level',) + + _VALUES_FLAG = 0x01 + _SKIP_METADATA_FLAG = 0x01 + _PAGE_SIZE_FLAG = 0x04 + _WITH_PAGING_STATE_SIZE_FLAG = 0x08 + _WITH_SERIAL_CONSISTENCY_FLAG = 0x10 + + def __init__(self, query, consistency_level, values=None): + self.query = query + self.consistency_level = consistency_level + self.values = values def send_body(self, f): write_longstring(f, self.query) @@ -369,19 +387,18 @@ class QueryMessage(_MessageType): CUSTOM_TYPE = object() +RESULT_KIND_VOID = 0x0001 +RESULT_KIND_ROWS = 0x0002 +RESULT_KIND_SET_KEYSPACE = 0x0003 +RESULT_KIND_PREPARED = 0x0004 +RESULT_KIND_SCHEMA_CHANGE = 0x0005 + class ResultMessage(_MessageType): opcode = 0x08 name = 'RESULT' - params = ('kind', 'results',) - KIND_VOID = 0x0001 - KIND_ROWS = 0x0002 - KIND_SET_KEYSPACE = 0x0003 - KIND_PREPARED = 0x0004 - KIND_SCHEMA_CHANGE = 0x0005 - - type_codes = { + _type_codes = { 0x0000: CUSTOM_TYPE, 0x0001: AsciiType, 0x0002: LongType, @@ -404,21 +421,25 @@ class ResultMessage(_MessageType): 0x0022: SetType, } - FLAGS_GLOBAL_TABLES_SPEC = 0x0001 + _FLAGS_GLOBAL_TABLES_SPEC = 0x0001 + + def __init__(self, kind, results): + self.kind = kind + self.results = results @classmethod def recv_body(cls, f): kind = read_int(f) - if kind == cls.KIND_VOID: + if kind == RESULT_KIND_VOID: results = None - elif kind == cls.KIND_ROWS: + elif kind == RESULT_KIND_ROWS: results = cls.recv_results_rows(f) - elif kind == cls.KIND_SET_KEYSPACE: + elif kind == RESULT_KIND_SET_KEYSPACE: ksname = read_string(f) results = ksname - elif kind == cls.KIND_PREPARED: + elif kind == RESULT_KIND_PREPARED: results = cls.recv_results_prepared(f) - elif kind == cls.KIND_SCHEMA_CHANGE: + elif kind == RESULT_KIND_SCHEMA_CHANGE: results = cls.recv_results_schema_change(f) return cls(kind=kind, results=results) @@ -441,7 +462,7 @@ class ResultMessage(_MessageType): @classmethod def recv_results_metadata(cls, f): flags = read_int(f) - glob_tblspec = bool(flags & cls.FLAGS_GLOBAL_TABLES_SPEC) + glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) colcount = read_int(f) if glob_tblspec: ksname = read_string(f) @@ -470,7 +491,7 @@ class ResultMessage(_MessageType): def read_type(cls, f): optid = read_short(f) try: - typeclass = cls.type_codes[optid] + typeclass = cls._type_codes[optid] except KeyError: raise NotSupportedError("Unknown data type code 0x%04x. Have to skip" " entire result set." % (optid,)) @@ -495,7 +516,9 @@ class ResultMessage(_MessageType): class PrepareMessage(_MessageType): opcode = 0x09 name = 'PREPARE' - params = ('query',) + + def __init__(self, query): + self.query = query def send_body(self, f): write_longstring(f, self.query) @@ -504,7 +527,11 @@ class PrepareMessage(_MessageType): class ExecuteMessage(_MessageType): opcode = 0x0A name = 'EXECUTE' - params = ('query_id', 'query_params', 'consistency_level',) + + def __init__(self, query_id, query_params, consistency_level): + self.query_id = query_id + self.query_params = query_params + self.consistency_level = consistency_level def send_body(self, f): write_string(f, self.query_id) @@ -517,7 +544,11 @@ class ExecuteMessage(_MessageType): class BatchMessage(_MessageType): opcode = 0x0D name = 'BATCH' - params = ('batch_type', 'queries', 'consistency_level',) + + def __init__(self, batch_type, queries, consistency_level): + self.batch_type = batch_type + self.queries = queries + self.consistency_level = consistency_level def send_body(self, f): write_byte(f, self.batch_type.value) @@ -546,16 +577,21 @@ known_event_types = frozenset(( class RegisterMessage(_MessageType): opcode = 0x0B name = 'REGISTER' - params = ('event_list',) - def send_body(self, f): + def __init__(self, event_list): + self.event_list = event_list + + def send_body(self, f, protocol_version): write_stringlist(f, self.event_list) class EventMessage(_MessageType): opcode = 0x0C name = 'EVENT' - params = ('event_type', 'event_args') + + def __init__(self, event_type, event_args): + self.event_type = event_type + self.event_args = event_args @classmethod def recv_body(cls, f): diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 2b7aec44..2a5f42bc 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -8,7 +8,7 @@ from mock import Mock, ANY from concurrent.futures import ThreadPoolExecutor from cassandra import OperationTimedOut -from cassandra.decoder import ResultMessage +from cassandra.decoder import ResultMessage, RESULT_KIND_ROWS from cassandra.cluster import ControlConnection, Cluster, _Scheduler from cassandra.pool import Host from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, @@ -16,6 +16,7 @@ from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, PEER_IP = "foobar" + class MockMetadata(object): def __init__(self): @@ -91,9 +92,9 @@ class MockConnection(object): def wait_for_responses(self, peer_query, local_query, timeout=None): local_response = ResultMessage( - kind=ResultMessage.KIND_ROWS, results=self.local_results) + kind=RESULT_KIND_ROWS, results=self.local_results) peer_response = ResultMessage( - kind=ResultMessage.KIND_ROWS, results=self.peer_results) + kind=RESULT_KIND_ROWS, results=self.peer_results) return (peer_response, local_response) diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index ea05a86f..b617b6ef 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -11,11 +11,14 @@ from cassandra.connection import ConnectionException from cassandra.decoder import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, - PreparedQueryNotFound, PrepareMessage) + PreparedQueryNotFound, PrepareMessage, + RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, + RESULT_KIND_SCHEMA_CHANGE) from cassandra.policies import RetryPolicy from cassandra.pool import NoConnectionsAvailable from cassandra.query import SimpleStatement + class ResponseFutureTests(unittest.TestCase): def make_basic_session(self): @@ -46,7 +49,7 @@ class ResponseFutureTests(unittest.TestCase): connection = pool.borrow_connection.return_value connection.send_msg.assert_called_once_with(rf.message, cb=ANY) - response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) + response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(response) result = rf.result() @@ -65,7 +68,7 @@ class ResponseFutureTests(unittest.TestCase): rf.send_request() result = Mock(spec=ResultMessage, - kind=ResultMessage.KIND_SET_KEYSPACE, + kind=RESULT_KIND_SET_KEYSPACE, results="keyspace1") rf._set_result(result) rf._set_keyspace_completed({}) @@ -77,7 +80,7 @@ class ResponseFutureTests(unittest.TestCase): rf.send_request() result = Mock(spec=ResultMessage, - kind=ResultMessage.KIND_SCHEMA_CHANGE, + kind=RESULT_KIND_SCHEMA_CHANGE, results={'keyspace': "keyspace1", "table": "table1"}) rf._set_result(result) session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', ANY, rf) @@ -256,7 +259,7 @@ class ResponseFutureTests(unittest.TestCase): rf = self.make_response_future(session) rf.send_request() - response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) + response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(response) result = rf.result() @@ -277,7 +280,7 @@ class ResponseFutureTests(unittest.TestCase): rf = self.make_response_future(session) rf.send_request() - response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) + response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(response) self.assertEqual(rf.result(), [{'col': 'val'}]) @@ -291,7 +294,7 @@ class ResponseFutureTests(unittest.TestCase): rf.add_callback(self.assertEqual, [{'col': 'val'}]) - response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) + response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(response) result = rf.result() @@ -346,7 +349,7 @@ class ResponseFutureTests(unittest.TestCase): callback=self.assertEquals, callback_args=([{'col': 'val'}],), errback=self.assertIsInstance, errback_args=(Exception,)) - response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) + response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}]) rf._set_result(response) self.assertEqual(rf.result(), [{'col': 'val'}])