Switch from class-level params to init args

This commit is contained in:
Tyler Hobbs
2014-03-06 12:50:11 -06:00
parent 78b47453d3
commit 4bc59a1207
4 changed files with 105 additions and 67 deletions

View File

@@ -31,7 +31,9 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
PrepareMessage, ExecuteMessage, PrepareMessage, ExecuteMessage,
PreparedQueryNotFound, PreparedQueryNotFound,
IsBootstrappingErrorMessage, IsBootstrappingErrorMessage,
BatchMessage) BatchMessage, RESULT_KIND_PREPARED,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE)
from cassandra.metadata import Metadata from cassandra.metadata import Metadata
from cassandra.metrics import Metrics from cassandra.metrics import Metrics
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy, from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
@@ -855,7 +857,7 @@ class Cluster(object):
responses = connection.wait_for_responses(*messages, timeout=5.0) responses = connection.wait_for_responses(*messages, timeout=5.0)
for response in responses: for response in responses:
if (not isinstance(response, ResultMessage) or if (not isinstance(response, ResultMessage) or
response.kind != ResultMessage.KIND_PREPARED): response.kind != RESULT_KIND_PREPARED):
log.debug("Got unexpected response when preparing " log.debug("Got unexpected response when preparing "
"statement on host %s: %r", host, response) "statement on host %s: %r", host, response)
@@ -1068,18 +1070,15 @@ class Session(object):
if parameters: if parameters:
parameters = encode_params(parameters) parameters = encode_params(parameters)
message = QueryMessage( message = QueryMessage(
query=query.query_string, consistency_level=query.consistency_level, query.query_string, query.consistency_level, parameters)
parameters=parameters)
else: else:
query_string = query.query_string query_string = query.query_string
if parameters: if parameters:
query_string = bind_params(query.query_string, parameters) query_string = bind_params(query.query_string, parameters)
message = QueryMessage(query=query_string, consistency_level=query.consistency_level) message = QueryMessage(query_string, query.consistency_level)
elif isinstance(query, BoundStatement): elif isinstance(query, BoundStatement):
message = ExecuteMessage( message = ExecuteMessage(
query_id=query.prepared_statement.query_id, query.prepared_statement.query_id, query.values, query.consistency_level)
query_params=query.values,
consistency_level=query.consistency_level)
prepared_statement = query.prepared_statement prepared_statement = query.prepared_statement
elif isinstance(query, BatchStatement): elif isinstance(query, BatchStatement):
if self._protocol_version < 2: if self._protocol_version < 2:
@@ -1088,9 +1087,8 @@ class Session(object):
"2 or higher (supported in Cassandra 2.0 and higher). Consider " "2 or higher (supported in Cassandra 2.0 and higher). Consider "
"setting Cluster.protocol_version to 2 to support this operation.") "setting Cluster.protocol_version to 2 to support this operation.")
message = BatchMessage( message = BatchMessage(
batch_type=query.batch_type, query.batch_type, query._statements_and_parameters,
queries=query._statements_and_values, query.consistency_level)
consistency_level=query.consistency_level)
if trace: if trace:
message.tracing = True message.tracing = True
@@ -1949,7 +1947,7 @@ class ResponseFuture(object):
self._query_trace = QueryTrace(trace_id, self.session) self._query_trace = QueryTrace(trace_id, self.session)
if isinstance(response, ResultMessage): if isinstance(response, ResultMessage):
if response.kind == ResultMessage.KIND_SET_KEYSPACE: if response.kind == RESULT_KIND_SET_KEYSPACE:
session = getattr(self, 'session', None) session = getattr(self, 'session', None)
# since we're running on the event loop thread, we need to # since we're running on the event loop thread, we need to
# use a non-blocking method for setting the keyspace on # use a non-blocking method for setting the keyspace on
@@ -1961,7 +1959,7 @@ class ResponseFuture(object):
if session: if session:
session._set_keyspace_for_all_pools( session._set_keyspace_for_all_pools(
response.results, self._set_keyspace_completed) response.results, self._set_keyspace_completed)
elif response.kind == ResultMessage.KIND_SCHEMA_CHANGE: elif response.kind == RESULT_KIND_SCHEMA_CHANGE:
# refresh the schema before responding, but do it in another # refresh the schema before responding, but do it in another
# thread instead of the event loop thread # thread instead of the event loop thread
self.session.submit( self.session.submit(
@@ -1972,7 +1970,7 @@ class ResponseFuture(object):
self) self)
else: else:
results = getattr(response, 'results', None) results = getattr(response, 'results', None)
if results is not None and response.kind == ResultMessage.KIND_ROWS: if results is not None and response.kind == RESULT_KIND_ROWS:
results = self.row_factory(*results) results = self.row_factory(*results)
self._set_final_result(results) self._set_final_result(results)
elif isinstance(response, ErrorMessage): elif isinstance(response, ErrorMessage):
@@ -2107,7 +2105,7 @@ class ResponseFuture(object):
return return
if isinstance(response, ResultMessage): if isinstance(response, ResultMessage):
if response.kind == ResultMessage.KIND_PREPARED: if response.kind == RESULT_KIND_PREPARED:
# use self._query to re-use the same host and # use self._query to re-use the same host and
# at the same time properly borrow the connection # at the same time properly borrow the connection
request_id = self._query(self._current_host) request_id = self._query(self._current_host)

View File

@@ -48,19 +48,9 @@ class _register_msg_type(type):
class _MessageType(object): class _MessageType(object):
__metaclass__ = _register_msg_type __metaclass__ = _register_msg_type
params = ()
tracing = False tracing = False
def __init__(self, **kwargs):
for pname in self.params:
try:
pval = kwargs[pname]
except KeyError:
raise ValueError("%s instances need the %s keyword parameter"
% (self.__class__.__name__, pname))
setattr(self, pname, pval)
def to_string(self, stream_id, protocol_version, compression=None): def to_string(self, stream_id, protocol_version, compression=None):
body = StringIO() body = StringIO()
self.send_body(body) self.send_body(body)
@@ -77,11 +67,17 @@ class _MessageType(object):
return ''.join(msg_parts) return ''.join(msg_parts)
def __str__(self): def __str__(self):
paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in self.params] paramstrs = ['%s=%r' % (pname, getattr(self, pname)) for pname in _get_params(self)]
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs)) return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
__repr__ = __str__ __repr__ = __str__
def _get_params(message_obj):
base_attrs = dir(_MessageType)
return [a for a in dir(message_obj)
if a not in base_attrs and not a.startswith('_') and not callable(getattr(message_obj, a))]
def decode_response(stream_id, flags, opcode, body, decompressor=None): def decode_response(stream_id, flags, opcode, body, decompressor=None):
if flags & 0x01: if flags & 0x01:
if decompressor is None: if decompressor is None:
@@ -112,9 +108,13 @@ error_classes = {}
class ErrorMessage(_MessageType, Exception): class ErrorMessage(_MessageType, Exception):
opcode = 0x00 opcode = 0x00
name = 'ERROR' name = 'ERROR'
params = ('code', 'message', 'info')
summary = 'Unknown' summary = 'Unknown'
def __init__(self, code, message, info):
self.code = code
self.message = message
self.info = info
@classmethod @classmethod
def recv_body(cls, f): def recv_body(cls, f):
code = read_int(f) code = read_int(f)
@@ -290,13 +290,16 @@ class AlreadyExistsException(ConfigurationException):
class StartupMessage(_MessageType): class StartupMessage(_MessageType):
opcode = 0x01 opcode = 0x01
name = 'STARTUP' name = 'STARTUP'
params = ('cqlversion', 'options')
KNOWN_OPTION_KEYS = set(( KNOWN_OPTION_KEYS = set((
'CQL_VERSION', 'CQL_VERSION',
'COMPRESSION', 'COMPRESSION',
)) ))
def __init__(self, cqlversion, options):
self.cqlversion = cqlversion
self.options = options
def send_body(self, f): def send_body(self, f):
optmap = self.options.copy() optmap = self.options.copy()
optmap['CQL_VERSION'] = self.cqlversion optmap['CQL_VERSION'] = self.cqlversion
@@ -306,7 +309,6 @@ class StartupMessage(_MessageType):
class ReadyMessage(_MessageType): class ReadyMessage(_MessageType):
opcode = 0x02 opcode = 0x02
name = 'READY' name = 'READY'
params = ()
@classmethod @classmethod
def recv_body(cls, f): def recv_body(cls, f):
@@ -316,7 +318,9 @@ class ReadyMessage(_MessageType):
class AuthenticateMessage(_MessageType): class AuthenticateMessage(_MessageType):
opcode = 0x03 opcode = 0x03
name = 'AUTHENTICATE' name = 'AUTHENTICATE'
params = ('authenticator',)
def __init__(self, authenticator):
self.authenticator = authenticator
@classmethod @classmethod
def recv_body(cls, f): def recv_body(cls, f):
@@ -327,7 +331,9 @@ class AuthenticateMessage(_MessageType):
class CredentialsMessage(_MessageType): class CredentialsMessage(_MessageType):
opcode = 0x04 opcode = 0x04
name = 'CREDENTIALS' name = 'CREDENTIALS'
params = ('creds',)
def __init__(self, creds):
self.creds = creds
def send_body(self, f): def send_body(self, f):
write_short(f, len(self.creds)) write_short(f, len(self.creds))
@@ -339,7 +345,6 @@ class CredentialsMessage(_MessageType):
class OptionsMessage(_MessageType): class OptionsMessage(_MessageType):
opcode = 0x05 opcode = 0x05
name = 'OPTIONS' name = 'OPTIONS'
params = ()
def send_body(self, f): def send_body(self, f):
pass pass
@@ -348,7 +353,10 @@ class OptionsMessage(_MessageType):
class SupportedMessage(_MessageType): class SupportedMessage(_MessageType):
opcode = 0x06 opcode = 0x06
name = 'SUPPORTED' name = 'SUPPORTED'
params = ('cql_versions', 'options',)
def __init__(self, cql_versions, options):
self.cql_versions = cql_versions
self.options = options
@classmethod @classmethod
def recv_body(cls, f): def recv_body(cls, f):
@@ -360,7 +368,17 @@ class SupportedMessage(_MessageType):
class QueryMessage(_MessageType): class QueryMessage(_MessageType):
opcode = 0x07 opcode = 0x07
name = 'QUERY' name = 'QUERY'
params = ('query', 'consistency_level',)
_VALUES_FLAG = 0x01
_SKIP_METADATA_FLAG = 0x01
_PAGE_SIZE_FLAG = 0x04
_WITH_PAGING_STATE_SIZE_FLAG = 0x08
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
def __init__(self, query, consistency_level, values=None):
self.query = query
self.consistency_level = consistency_level
self.values = values
def send_body(self, f): def send_body(self, f):
write_longstring(f, self.query) write_longstring(f, self.query)
@@ -369,19 +387,18 @@ class QueryMessage(_MessageType):
CUSTOM_TYPE = object() CUSTOM_TYPE = object()
RESULT_KIND_VOID = 0x0001
RESULT_KIND_ROWS = 0x0002
RESULT_KIND_SET_KEYSPACE = 0x0003
RESULT_KIND_PREPARED = 0x0004
RESULT_KIND_SCHEMA_CHANGE = 0x0005
class ResultMessage(_MessageType): class ResultMessage(_MessageType):
opcode = 0x08 opcode = 0x08
name = 'RESULT' name = 'RESULT'
params = ('kind', 'results',)
KIND_VOID = 0x0001 _type_codes = {
KIND_ROWS = 0x0002
KIND_SET_KEYSPACE = 0x0003
KIND_PREPARED = 0x0004
KIND_SCHEMA_CHANGE = 0x0005
type_codes = {
0x0000: CUSTOM_TYPE, 0x0000: CUSTOM_TYPE,
0x0001: AsciiType, 0x0001: AsciiType,
0x0002: LongType, 0x0002: LongType,
@@ -404,21 +421,25 @@ class ResultMessage(_MessageType):
0x0022: SetType, 0x0022: SetType,
} }
FLAGS_GLOBAL_TABLES_SPEC = 0x0001 _FLAGS_GLOBAL_TABLES_SPEC = 0x0001
def __init__(self, kind, results):
self.kind = kind
self.results = results
@classmethod @classmethod
def recv_body(cls, f): def recv_body(cls, f):
kind = read_int(f) kind = read_int(f)
if kind == cls.KIND_VOID: if kind == RESULT_KIND_VOID:
results = None results = None
elif kind == cls.KIND_ROWS: elif kind == RESULT_KIND_ROWS:
results = cls.recv_results_rows(f) results = cls.recv_results_rows(f)
elif kind == cls.KIND_SET_KEYSPACE: elif kind == RESULT_KIND_SET_KEYSPACE:
ksname = read_string(f) ksname = read_string(f)
results = ksname results = ksname
elif kind == cls.KIND_PREPARED: elif kind == RESULT_KIND_PREPARED:
results = cls.recv_results_prepared(f) results = cls.recv_results_prepared(f)
elif kind == cls.KIND_SCHEMA_CHANGE: elif kind == RESULT_KIND_SCHEMA_CHANGE:
results = cls.recv_results_schema_change(f) results = cls.recv_results_schema_change(f)
return cls(kind=kind, results=results) return cls(kind=kind, results=results)
@@ -441,7 +462,7 @@ class ResultMessage(_MessageType):
@classmethod @classmethod
def recv_results_metadata(cls, f): def recv_results_metadata(cls, f):
flags = read_int(f) flags = read_int(f)
glob_tblspec = bool(flags & cls.FLAGS_GLOBAL_TABLES_SPEC) glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
colcount = read_int(f) colcount = read_int(f)
if glob_tblspec: if glob_tblspec:
ksname = read_string(f) ksname = read_string(f)
@@ -470,7 +491,7 @@ class ResultMessage(_MessageType):
def read_type(cls, f): def read_type(cls, f):
optid = read_short(f) optid = read_short(f)
try: try:
typeclass = cls.type_codes[optid] typeclass = cls._type_codes[optid]
except KeyError: except KeyError:
raise NotSupportedError("Unknown data type code 0x%04x. Have to skip" raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
" entire result set." % (optid,)) " entire result set." % (optid,))
@@ -495,7 +516,9 @@ class ResultMessage(_MessageType):
class PrepareMessage(_MessageType): class PrepareMessage(_MessageType):
opcode = 0x09 opcode = 0x09
name = 'PREPARE' name = 'PREPARE'
params = ('query',)
def __init__(self, query):
self.query = query
def send_body(self, f): def send_body(self, f):
write_longstring(f, self.query) write_longstring(f, self.query)
@@ -504,7 +527,11 @@ class PrepareMessage(_MessageType):
class ExecuteMessage(_MessageType): class ExecuteMessage(_MessageType):
opcode = 0x0A opcode = 0x0A
name = 'EXECUTE' name = 'EXECUTE'
params = ('query_id', 'query_params', 'consistency_level',)
def __init__(self, query_id, query_params, consistency_level):
self.query_id = query_id
self.query_params = query_params
self.consistency_level = consistency_level
def send_body(self, f): def send_body(self, f):
write_string(f, self.query_id) write_string(f, self.query_id)
@@ -517,7 +544,11 @@ class ExecuteMessage(_MessageType):
class BatchMessage(_MessageType): class BatchMessage(_MessageType):
opcode = 0x0D opcode = 0x0D
name = 'BATCH' name = 'BATCH'
params = ('batch_type', 'queries', 'consistency_level',)
def __init__(self, batch_type, queries, consistency_level):
self.batch_type = batch_type
self.queries = queries
self.consistency_level = consistency_level
def send_body(self, f): def send_body(self, f):
write_byte(f, self.batch_type.value) write_byte(f, self.batch_type.value)
@@ -546,16 +577,21 @@ known_event_types = frozenset((
class RegisterMessage(_MessageType): class RegisterMessage(_MessageType):
opcode = 0x0B opcode = 0x0B
name = 'REGISTER' name = 'REGISTER'
params = ('event_list',)
def send_body(self, f): def __init__(self, event_list):
self.event_list = event_list
def send_body(self, f, protocol_version):
write_stringlist(f, self.event_list) write_stringlist(f, self.event_list)
class EventMessage(_MessageType): class EventMessage(_MessageType):
opcode = 0x0C opcode = 0x0C
name = 'EVENT' name = 'EVENT'
params = ('event_type', 'event_args')
def __init__(self, event_type, event_args):
self.event_type = event_type
self.event_args = event_args
@classmethod @classmethod
def recv_body(cls, f): def recv_body(cls, f):

View File

@@ -8,7 +8,7 @@ from mock import Mock, ANY
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from cassandra import OperationTimedOut from cassandra import OperationTimedOut
from cassandra.decoder import ResultMessage from cassandra.decoder import ResultMessage, RESULT_KIND_ROWS
from cassandra.cluster import ControlConnection, Cluster, _Scheduler from cassandra.cluster import ControlConnection, Cluster, _Scheduler
from cassandra.pool import Host from cassandra.pool import Host
from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy,
@@ -16,6 +16,7 @@ from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy,
PEER_IP = "foobar" PEER_IP = "foobar"
class MockMetadata(object): class MockMetadata(object):
def __init__(self): def __init__(self):
@@ -91,9 +92,9 @@ class MockConnection(object):
def wait_for_responses(self, peer_query, local_query, timeout=None): def wait_for_responses(self, peer_query, local_query, timeout=None):
local_response = ResultMessage( local_response = ResultMessage(
kind=ResultMessage.KIND_ROWS, results=self.local_results) kind=RESULT_KIND_ROWS, results=self.local_results)
peer_response = ResultMessage( peer_response = ResultMessage(
kind=ResultMessage.KIND_ROWS, results=self.peer_results) kind=RESULT_KIND_ROWS, results=self.peer_results)
return (peer_response, local_response) return (peer_response, local_response)

View File

@@ -11,11 +11,14 @@ from cassandra.connection import ConnectionException
from cassandra.decoder import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, from cassandra.decoder import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage,
UnavailableErrorMessage, ResultMessage, QueryMessage, UnavailableErrorMessage, ResultMessage, QueryMessage,
OverloadedErrorMessage, IsBootstrappingErrorMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage,
PreparedQueryNotFound, PrepareMessage) PreparedQueryNotFound, PrepareMessage,
RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE,
RESULT_KIND_SCHEMA_CHANGE)
from cassandra.policies import RetryPolicy from cassandra.policies import RetryPolicy
from cassandra.pool import NoConnectionsAvailable from cassandra.pool import NoConnectionsAvailable
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
class ResponseFutureTests(unittest.TestCase): class ResponseFutureTests(unittest.TestCase):
def make_basic_session(self): def make_basic_session(self):
@@ -46,7 +49,7 @@ class ResponseFutureTests(unittest.TestCase):
connection = pool.borrow_connection.return_value connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_once_with(rf.message, cb=ANY) connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response) rf._set_result(response)
result = rf.result() result = rf.result()
@@ -65,7 +68,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request() rf.send_request()
result = Mock(spec=ResultMessage, result = Mock(spec=ResultMessage,
kind=ResultMessage.KIND_SET_KEYSPACE, kind=RESULT_KIND_SET_KEYSPACE,
results="keyspace1") results="keyspace1")
rf._set_result(result) rf._set_result(result)
rf._set_keyspace_completed({}) rf._set_keyspace_completed({})
@@ -77,7 +80,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request() rf.send_request()
result = Mock(spec=ResultMessage, result = Mock(spec=ResultMessage,
kind=ResultMessage.KIND_SCHEMA_CHANGE, kind=RESULT_KIND_SCHEMA_CHANGE,
results={'keyspace': "keyspace1", "table": "table1"}) results={'keyspace': "keyspace1", "table": "table1"})
rf._set_result(result) rf._set_result(result)
session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', ANY, rf) session.submit.assert_called_once_with(ANY, 'keyspace1', 'table1', ANY, rf)
@@ -256,7 +259,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session) rf = self.make_response_future(session)
rf.send_request() rf.send_request()
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response) rf._set_result(response)
result = rf.result() result = rf.result()
@@ -277,7 +280,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session) rf = self.make_response_future(session)
rf.send_request() rf.send_request()
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response) rf._set_result(response)
self.assertEqual(rf.result(), [{'col': 'val'}]) self.assertEqual(rf.result(), [{'col': 'val'}])
@@ -291,7 +294,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.add_callback(self.assertEqual, [{'col': 'val'}]) rf.add_callback(self.assertEqual, [{'col': 'val'}])
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response) rf._set_result(response)
result = rf.result() result = rf.result()
@@ -346,7 +349,7 @@ class ResponseFutureTests(unittest.TestCase):
callback=self.assertEquals, callback_args=([{'col': 'val'}],), callback=self.assertEquals, callback_args=([{'col': 'val'}],),
errback=self.assertIsInstance, errback_args=(Exception,)) errback=self.assertIsInstance, errback_args=(Exception,))
response = Mock(spec=ResultMessage, kind=ResultMessage.KIND_ROWS, results=[{'col': 'val'}]) response = Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=[{'col': 'val'}])
rf._set_result(response) rf._set_result(response)
self.assertEqual(rf.result(), [{'col': 'val'}]) self.assertEqual(rf.result(), [{'col': 'val'}])